#include "ast/visitor.hpp"
#include "ast/ast.hpp"
#include "runtime/dict_object.hpp"
#include "runtime/functionObject.hpp"
#include "env.hpp"
#include "type/type.hpp"

#include <stdio.h>
#include <stdlib.h>

#include <vector>

// Dispatch helper: hands this visitor to the node via accept(). Which
// visit() overload ultimately runs is decided by the node's accept().
void Visitor::visit_node(Node* n) { n->accept(this); }

// Defines Dumper::visit for a binary node type `NodeType`, printing it in
// prefix form as "(<op> <left> <right>)".
// The operator text is passed to printf as a "%s" argument instead of being
// pasted into the format string, so an operator such as "%" cannot be
// misinterpreted as a printf conversion specification.
#define DEF_DUMPER_VISIT(NodeType, op)  \
void Dumper::visit(NodeType* n) {       \
    printf("(%s ", op);                 \
    visit_node(n->left());              \
    printf(" ");                        \
    visit_node(n->right());             \
    printf(")");                        \
}

// Prefix dumpers for the logical binary operators. The macro expands to a
// complete function definition, so no trailing semicolon is needed.
DEF_DUMPER_VISIT(LogicOrNode, "||")
DEF_DUMPER_VISIT(LogicAndNode, "&&")

// Dumps a logical negation in prefix form: "(! <operand>)".
void Dumper::visit(LogicNotNode* n) {
    printf("(! ");
    visit_node(n->value());
    putchar(')');
}

// Bitwise and shift operators. Each macro use expands to a complete
// function definition, so no trailing semicolon is needed.
DEF_DUMPER_VISIT(BitOrNode, "|")
DEF_DUMPER_VISIT(BitXorNode, "^")
DEF_DUMPER_VISIT(BitAndNode, "&")
DEF_DUMPER_VISIT(LeftShiftNode, "<<")
DEF_DUMPER_VISIT(RightShiftNode, ">>")

// Arithmetic operators.
DEF_DUMPER_VISIT(AddNode, "+")
DEF_DUMPER_VISIT(SubNode, "-")
DEF_DUMPER_VISIT(MulNode, "*")
DEF_DUMPER_VISIT(DivNode, "/")

// Dumps a node list as "[" followed by each element and ", \n" (including
// after the last element), then "]\n".
void Dumper::visit(ListNode* n) {
    printf("[");
    for (auto& item : *n->node_list()) {
        item->accept(this);
        printf(", \n");
    }
    printf("]\n");
}

// Dumps an integer literal using its stored IntObject value.
void Dumper::visit(ConstInt* n) {
    const auto number = n->value()->value();
    printf("%d", number);
}

// Dumps a boolean literal as "True" or "False".
void Dumper::visit(ConstBool* n) {
    printf("%s", n->value() ? "True" : "False");
}

// Dumps a string literal's raw text (no surrounding quotes).
void Dumper::visit(ConstString* n) {
    const auto text = n->value()->value();
    printf("%s", text);
}

// Dumps a character literal as the bare character.
void Dumper::visit(ConstChar* n) {
    const auto ch = n->value()->value();
    printf("%c", ch);
}

// Dumps a variable reference by name.
void Dumper::visit(VarNode* n) {
    const auto identifier = n->name();
    printf("%s", identifier);
}

// Dumps a variable definition as "(var name[ : type][, init])".
// Both the declared type and the initializer are optional (may be NULL).
void Dumper::visit(VarDefNode* n) {
    printf("(var %s", n->name());

    if (n->type() != nullptr) {
        printf(" : %s", n->type()->to_string());
    }

    if (n->init_value() != nullptr) {
        printf(", ");
        n->init_value()->accept(this);
    }

    putchar(')');
}

void Dumper::visit(TypeDefNode* n) {
}

// Dumps an assignment in prefix form: "(= <target> <value>)".
void Dumper::visit(AssignNode* n) {
    printf("(= ");
    n->left()->accept(this);
    putchar(' ');
    n->right()->accept(this);
    putchar(')');
}

// Dumps a lambda as "(define <param> <body>)".
void Dumper::visit(LambdaDef* n) {
    printf("(define ");
    n->param()->accept(this);
    putchar(' ');
    n->body()->accept(this);
    putchar(')');
}

// Dumps a call as "(apply <callee>, <argument>)".
void Dumper::visit(CallNode* n) {
    printf("(apply ");
    n->func_name()->accept(this);
    printf(", ");
    n->param()->accept(this);
    putchar(')');
}

// Dumps a print statement as "(print <expr>)".
void Dumper::visit(PrintNode* n) {
    printf("(print ");
    n->body()->accept(this);
    putchar(')');
}

// Dumps a println statement as "(println <expr>)".
void Dumper::visit(PrintlnNode* n) {
    printf("(println ");
    n->body()->accept(this);
    putchar(')');
}

// Dumps an equality comparison as "(== <left> <right>)".
// NOTE(review): any operator other than N_EQUAL currently produces no output
// at all, so such comparisons silently vanish from the dump.
void Dumper::visit(CmpNode* n) {
    if (n->cmp_op() != N_EQUAL) {
        return;
    }

    printf("(== ");
    n->left()->accept(this);
    putchar(' ');
    n->right()->accept(this);
    putchar(')');
}

// Dumps a conditional as "(if <cond> then <then-block> else <else-block>)\n".
// The " else ..." part is omitted when the node has no else block; the child
// is NULL-checked (like the optional children of VarDefNode) instead of being
// dereferenced unconditionally.
void Dumper::visit(IfNode* n) {
    printf("(if ");
    n->cond()->accept(this);
    printf(" then ");
    n->then_block()->accept(this);
    if (n->else_block() != NULL) {
        printf(" else ");
        n->else_block()->accept(this);
    }
    printf(")\n");
}

// ---- Evaluator: walks the AST and leaves each node's value in _result ----
// Creates the root stack frame and registers the native "print" and
// "println" functions in it. No value has been computed yet, so _result
// starts out NULL.
Evaluator::Evaluator() {
    _result = NULL;

    StackFrame* root = new StackFrame();
    root->save_var("print", new NativeFunctionObject(sys_print));
    root->save_var("println", new NativeFunctionObject(sys_println));
    _frame = root;
}

// Short-circuit OR: if the left operand yields BoolObject::TrueValue, that is
// the result; otherwise the right operand is evaluated and its value becomes
// the result.
void Evaluator::visit(LogicOrNode* n) {
    n->left()->accept(this);
    if (_result != BoolObject::TrueValue) {
        n->right()->accept(this);
    }
}

// Short-circuit AND: if the left operand yields BoolObject::FalseValue, that
// is the result; otherwise the right operand is evaluated and its value
// becomes the result.
void Evaluator::visit(LogicAndNode* n) {
    n->left()->accept(this);
    if (_result != BoolObject::FalseValue) {
        n->right()->accept(this);
    }
}

// Logical NOT: yields TrueValue only when the operand is exactly FalseValue;
// any other value (including non-booleans) yields FalseValue.
void Evaluator::visit(LogicNotNode* n) {
    n->value()->accept(this);
    const bool operand_is_false = (_result == BoolObject::FalseValue);
    _result = operand_is_false ? BoolObject::TrueValue : BoolObject::FalseValue;
}

// Evaluates `left | right`. Operands are evaluated left-to-right and treated
// as IntObjects (the casts are unchecked).
void Evaluator::visit(BitOrNode* n) {
    visit_node(n->left());
    IntObject* const lhs = static_cast<IntObject*>(_result);
    visit_node(n->right());
    IntObject* const rhs = static_cast<IntObject*>(_result);
    _result = lhs->bit_or(rhs);
}

// Evaluates `left ^ right`. Operands are evaluated left-to-right and treated
// as IntObjects (the casts are unchecked).
void Evaluator::visit(BitXorNode* n) {
    visit_node(n->left());
    IntObject* const lhs = static_cast<IntObject*>(_result);
    visit_node(n->right());
    IntObject* const rhs = static_cast<IntObject*>(_result);
    _result = lhs->bit_xor(rhs);
}

// Evaluates `left & right`. Operands are evaluated left-to-right and treated
// as IntObjects (the casts are unchecked).
void Evaluator::visit(BitAndNode* n) {
    visit_node(n->left());
    IntObject* const lhs = static_cast<IntObject*>(_result);
    visit_node(n->right());
    IntObject* const rhs = static_cast<IntObject*>(_result);
    _result = lhs->bit_and(rhs);
}

// Evaluates `left << right`. Operands are evaluated left-to-right and treated
// as IntObjects (the casts are unchecked).
void Evaluator::visit(LeftShiftNode* n) {
    visit_node(n->left());
    IntObject* const value = static_cast<IntObject*>(_result);
    visit_node(n->right());
    IntObject* const amount = static_cast<IntObject*>(_result);
    _result = value->lshift(amount);
}

// Evaluates `left >> right`. Operands are evaluated left-to-right and treated
// as IntObjects (the casts are unchecked).
void Evaluator::visit(RightShiftNode* n) {
    visit_node(n->left());
    IntObject* const value = static_cast<IntObject*>(_result);
    visit_node(n->right());
    IntObject* const amount = static_cast<IntObject*>(_result);
    _result = value->rshift(amount);
}

// Evaluates `left + right` by delegating to the left operand's add(), so the
// operand objects themselves decide what "+" means.
void Evaluator::visit(AddNode* n) {
    visit_node(n->left());
    Object* const augend = _result;
    visit_node(n->right());
    Object* const addend = _result;
    _result = augend->add(addend);
}

// Evaluates `left - right` for int or char left operands. The right operand
// is evaluated only after the left operand's type is known, and is assumed to
// have the same type (unchecked cast). Any other left type prints an error
// and terminates the process with exit(-1).
void Evaluator::visit(SubNode* n) {
    visit_node(n->left());
    Object* const lhs = _result;

    if (lhs->is_int()) {
        visit_node(n->right());
        _result = static_cast<IntObject*>(lhs)->sub(static_cast<IntObject*>(_result));
        return;
    }

    if (lhs->is_char()) {
        visit_node(n->right());
        _result = static_cast<CharObject*>(lhs)->sub(static_cast<CharObject*>(_result));
        return;
    }

    printf("Unexpected type for sub\n");
    exit(-1);
}

// Evaluates `left * right`. Operands are evaluated left-to-right and treated
// as IntObjects (the casts are unchecked).
void Evaluator::visit(MulNode* n) {
    visit_node(n->left());
    IntObject* const multiplicand = static_cast<IntObject*>(_result);
    visit_node(n->right());
    IntObject* const multiplier = static_cast<IntObject*>(_result);
    _result = multiplicand->mul(multiplier);
}

// Evaluates `left / right`. Operands are evaluated left-to-right and treated
// as IntObjects (the casts are unchecked). Division-by-zero handling, if any,
// is up to IntObject::div.
void Evaluator::visit(DivNode* n) {
    visit_node(n->left());
    IntObject* const dividend = static_cast<IntObject*>(_result);
    visit_node(n->right());
    IntObject* const divisor = static_cast<IntObject*>(_result);
    _result = dividend->div(divisor);
}

// Evaluates each node of the list in order. _result is left holding the
// value of the last element (or is unchanged for an empty list).
void Evaluator::visit(ListNode* n) {
    for (auto& stmt : *n->node_list()) {
        stmt->accept(this);
    }
}

// An integer literal evaluates to the IntObject stored on the node.
void Evaluator::visit(ConstInt* n) { _result = n->value(); }

// A boolean literal evaluates to one of the shared BoolObject singletons.
void Evaluator::visit(ConstBool* n) {
    _result = n->value() ? BoolObject::TrueValue : BoolObject::FalseValue;
}

// A string literal evaluates to the string object stored on the node.
void Evaluator::visit(ConstString* n) { _result = n->value(); }

// A character literal evaluates to the char object stored on the node.
void Evaluator::visit(ConstChar* n) { _result = n->value(); }

// Looks the variable up in the current frame. An unbound name prints an
// error and terminates the process with exit(1).
void Evaluator::visit(VarNode* n) {
    Object* const bound = _frame->lookup(n->name());
    if (bound != nullptr) {
        _result = bound;
        return;
    }

    printf("undefined variable :%s\n", n->name());
    exit(1);
}

// Binds a new variable in the current frame. Its value is the evaluated
// initializer, or UnitObject::UnitValue when there is no initializer.
// Without an initializer, _result is left unchanged.
void Evaluator::visit(VarDefNode* n) {
    Object* initial = UnitObject::UnitValue;
    if (n->init_value() != nullptr) {
        n->init_value()->accept(this);
        initial = _result;
    }
    _frame->save_var(n->name(), initial);
}

void Evaluator::visit(TypeDefNode* n) {
}

// Evaluates the right-hand side and stores it under the target's name in the
// current frame. _result keeps the assigned value.
// NOTE(review): the target is assumed to be a VarNode; the cast is unchecked.
void Evaluator::visit(AssignNode* n) {
    n->right()->accept(this);
    VarNode* const target = static_cast<VarNode*>(n->left());
    _frame->save_var(target->name(), _result);
}

// A lambda evaluates to a new ClosureObject pairing the definition with a
// DictObject built from the current stack frame (presumably a copy of the
// visible bindings; confirm in DictObject::from_stack_frame).
void Evaluator::visit(LambdaDef* n) {
    ClosureObject* const closure =
        new ClosureObject(n, DictObject::from_stack_frame(_frame));
    _result = closure;
}

// Calls a one-argument function. The callee is evaluated before the argument.
//  - Native functions receive the argument in a one-element vector, and their
//    return value becomes _result.
//  - Anything else is treated as a ClosureObject (unchecked cast): its body
//    runs in a fresh StackFrame created from the closure's environment, with
//    the parameter bound to the argument. That frame is deleted afterwards and
//    the caller's frame restored; _result holds the body's value.
void Evaluator::visit(CallNode* n) {
    n->func_name()->accept(this);
    Object* const callee = _result;
    n->param()->accept(this);
    Object* const argument = _result;

    if (!callee->is_native_func()) {
        ClosureObject* const closure = static_cast<ClosureObject*>(callee);
        LambdaDef* const def = closure->func_def();

        StackFrame* const caller_frame = _frame;
        _frame = new StackFrame(closure->env());
        _frame->save_var(def->param()->name(), argument);
        def->body()->accept(this);
        delete _frame;
        _frame = caller_frame;
        return;
    }

    std::vector<Object*> args(1, argument);
    _result = static_cast<NativeFunctionObject*>(callee)->get_func()(args);
}

// Evaluates the expression and prints it via Object::print (no newline).
void Evaluator::visit(PrintNode* n) {
    n->body()->accept(this);
    Object* const value = _result;
    value->print();
}

// Evaluates the expression, prints it via Object::print, then a newline.
void Evaluator::visit(PrintlnNode* n) {
    n->body()->accept(this);
    Object* const value = _result;
    value->print();
    putchar('\n');
}

// Evaluates a comparison and yields BoolObject::TrueValue/FalseValue.
// Only N_EQUAL is implemented. Any other operator is reported and terminates
// the process (same style as the SubNode type error) instead of being
// silently evaluated as equality, which would give wrong results.
// NOTE(review): both operands are assumed to be IntObjects; the casts are
// unchecked.
void Evaluator::visit(CmpNode* n) {
    if (n->cmp_op() != N_EQUAL) {
        printf("Unsupported comparison operator\n");
        exit(-1);
    }

    n->left()->accept(this);
    IntObject* left = static_cast<IntObject*>(_result);
    n->right()->accept(this);
    IntObject* right = static_cast<IntObject*>(_result);

    _result = left->equal(right) ? BoolObject::TrueValue : BoolObject::FalseValue;
}

// Evaluates the condition. Only BoolObject::TrueValue selects the then-block;
// any other value selects the else-block. _result holds the value of the
// executed block. The else block is NULL-checked (like the optional
// initializer in VarDefNode); when it is absent and the condition is not
// true, the expression yields UnitObject::UnitValue instead of dereferencing
// NULL.
void Evaluator::visit(IfNode* n) {
    n->cond()->accept(this);
    if (_result == BoolObject::TrueValue) {
        n->then_block()->accept(this);
    }
    else if (n->else_block() != NULL) {
        n->else_block()->accept(this);
    }
    else {
        _result = UnitObject::UnitValue;
    }
}

// Releases the root stack frame owned by the evaluator and clears the
// cached result pointer (the result object itself is not deleted here).
Evaluator::~Evaluator() {
    delete _frame;  // deleting a NULL pointer is a no-op, so no guard needed
    _frame = nullptr;
    _result = nullptr;
}

